ORTModule GraphTransitionManager #19007
Conversation
## Dependency

#19007

## ORTModule memory-efficient gradient management

Previously I tried to solve the coarse-grained gradient accumulation/update problem in ORTModule with #8979, but that solution was not fully validated with DDP or when user hooks are registered on the gradient accumulation of torch parameters. This PR addresses the problem with a similar approach to PR #8979, i.e. trigger gradient accumulation once ORT has computed the grad, but instead of using an AccumulateGrad op, this time it uses the ONNX operator PythonOp, which internally calls param.backward(grad) so that all related hooks are handled correctly (a minimal sketch of this idea follows below).

## Design

Check the details from https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation

![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

Differences are mostly on the order of 0.000x, sometimes 0.00x, which may come from the different order in which the gradient apply happens before or after this change (on DeepSpeed ZeRO stage 2).

## TODO

Consolidate the logic with Stage 3's similar logic.
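Sketch of the accumulation idea described above: routing an ORT-computed gradient through autograd via param.backward(grad) instead of writing param.grad directly, so hooks registered on the parameter still run. This is a minimal standalone illustration, not the PR's actual code; the tensor values and the hook are made up.

```python
import torch

# A leaf parameter with a gradient hook attached (as DDP or user code might do).
param = torch.nn.Parameter(torch.ones(3))
param.register_hook(lambda g: print("hook fired with", g))

# Pretend this gradient was computed by ORT for this parameter.
ort_grad = torch.full((3,), 0.5)

# Accumulate through autograd: param.grad is updated and the hook fires,
# rather than silently assigning param.grad by hand.
param.backward(ort_grad)
print(param.grad)  # ~ tensor([0.5, 0.5, 0.5])

# A second call accumulates rather than overwrites.
param.backward(ort_grad)
print(param.grad)  # ~ tensor([1.0, 1.0, 1.0])
```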
Sorry for the late response.
construct_inputs and restore_outputs can probably be done by calling tree_flatten and tree_unflatten in https://github.com/pytorch/pytorch/blob/15bd81bfafa86fec9d675e7f071c867c852ebe8f/torch/utils/_pytree.py#L799.
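For reference, a small standalone sketch of that suggestion; the nested structure here is invented for illustration, and tree_flatten / tree_unflatten come from torch.utils._pytree as linked above.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# An arbitrary nested structure, as a user module's forward() might produce.
outputs = {"logits": torch.zeros(2, 4), "aux": (torch.ones(2), [torch.zeros(1)])}

# Flatten into a flat list of leaf tensors plus a spec describing the structure.
flat, spec = tree_flatten(outputs)

# ... hand the flat tensors to ORT and get flat results back (placeholder) ...
flat_results = [t + 1 for t in flat]

# Rebuild the original nested structure around the new leaves.
restored = tree_unflatten(flat_results, spec)
```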
LGTM now. Thanks.
Thanks @wschin.
Problem
Currently, the codebase contains logic for model re-export checks and graph_builder reinitialization checks. Ideally, these operations should behave like a state machine, but in the current implementation the relevant states are checked or set in scattered locations, which makes it hard to tell when a re-export or re-initialization will be triggered. For clarity and maintainability, these states should be consolidated into a cohesive component rather than dispersed within the current graph execution manager.
Furthermore, model export and post-export processing for stage 3 support or memory-efficient gradient management introduce considerable complexity. To improve the codebase's structure, it would be beneficial to extract these functionalities into a dedicated component, separate from the current graph execution manager.
As part of this effort, it is also essential to address inconsistencies in handling input/output flatten/unflatten operations. Currently, several functions perform these operations recursively, each with a slightly different implementation, so different parts of the code support different input/output data types and structures. This pull request simplifies these operations into a set of primitive functions, ensuring uniformity; this streamlines the code and makes it easier to stay consistent when fixing bugs or supporting new data types. One thing to note: input/output handling is deeply bound to the graph transition mentioned above, so it is difficult to make this change separately.
While this logic is complex, the codebase benefits from an extensive suite of unit tests covering all possible branches. Ensuring that all tests pass has been a time-intensive but necessary part of this development effort.
Design
- Introduce `GraphTransitionManager` and put all model export and post-export processing logic in it.
- Introduce `PostExportProcessedModelInfo`, which contains all the information we need to pass to ORT to build the gradient graph (currently we do the same for training and evaluating, but ideally we should not do it for evaluating; let's keep this behavior as it is now and make the change later).
- The `GraphTransitionManager` instance is a property of `GraphExecutionManager` (e.g. `TrainingManager` or `InferenceManager`).
- Use `self._graph_transition_manager._post_export_processed_model_info.construct_inputs` to construct the list of inputs used for ORT runs.
- Use `self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)` to restore the outputs in the original PyTorch output structure.
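For illustration, a hedged sketch of how an execution manager might use these entry points during forward. The attribute names come from this PR's description; the class, the method signatures, and the ORT session call are placeholders, not the actual ORT code.

```python
# Hypothetical call pattern only; construct_inputs / restore_outputs signatures
# and the ORT run call are assumptions for illustration.
class TrainingManagerSketch:
    def __init__(self, graph_transition_manager, ort_session):
        self._graph_transition_manager = graph_transition_manager
        self._ort_session = ort_session  # placeholder for the ORT execution agent

    def forward(self, *inputs, **kwargs):
        # The transition manager owns export / re-export and post-export
        # processing decisions; the execution manager consumes its result.
        model_info = self._graph_transition_manager._post_export_processed_model_info

        # Build the flat list of inputs ORT expects from the user's nested inputs.
        ort_inputs = model_info.construct_inputs(inputs, kwargs)

        # Run the ORT-built graph (placeholder call).
        user_outputs = self._ort_session.run(ort_inputs)

        # Restore the user's original (possibly nested) output structure.
        return model_info.restore_outputs(user_outputs)
```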
Motivation and Context